spikeinterface peak detection

peak detection in spikeinterface

Author : Samuel Garcia

spikeinterface implements several methods for peak detection.

Peak detection can be used:

  1. as the first step of a spike sorting chain
  2. as the first step for estimating motion (aka drift)

Here we will illustrate how this works and also how, in conjunction with the preprocessing module, we can compute this detection with or without caching the preprocessed traces on disk.

This example will be run on Neuropixels 1 and Neuropixels 2 recordings made by Nick Steinmetz, available here.

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
# %matplotlib widget
%matplotlib inline
In [4]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.full as si

open dataset

In [23]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP2'
preprocess_folder_bin = base_folder / 'dataset1_NP2_preprocessed_binary'
preprocess_folder_zarr = base_folder / 'dataset1_NP2_preprocessed_zarr'
In [21]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[21]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [22]:
fig, ax = plt.subplots(figsize=(7, 20))
si.plot_probe_map(rec, with_channel_ids=True, ax=ax)
ax.set_ylim(-150, 200)
Out[22]:
(-150.0, 200.0)

preprocess

Here we will apply filtering + CMR

And to demonstrate the flexibility, we will work on 3 objects:

  • the lazy object rec_preprocessed
  • the cached object in binary format rec_preprocessed_cached_binary
  • the cached object in zarr format rec_preprocessed_cached_zarr

We will also compare how long caching to binary and caching to zarr take.

In [24]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=20,
    chunk_duration='1s',
    progress_bar=True,
)
In [27]:
# create the lazy object
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
print(rec_preprocessed)
CommonReferenceRecording: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [29]:
# if it does not exist yet, cache to binary
if preprocess_folder_bin.exists():
    rec_preprocessed_cached_binary = si.load_extractor(preprocess_folder_bin)
else:
    # cache to binary
    rec_preprocessed_cached_binary = rec_preprocessed.save(folder=preprocess_folder_bin, format='binary', **job_kwargs)
write_binary_recording with n_jobs 40  chunk_size 30000
write_binary_recording: 100%|██████████| 1957/1957 [03:41<00:00,  8.85it/s]
In [30]:
print(rec_preprocessed_cached_binary)
BinaryRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
  file_paths: ['/mnt/data/sam/DataSpikeSorting/imposed_motion_nick/dataset1_NP2_preprocessed_binary/traces_cached_seg0.raw']
In [32]:
if preprocess_folder_zarr.exists():
    rec_preprocessed_cached_zarr = si.load_extractor(preprocess_folder_zarr)
else:
    # cache to zarr
    rec_preprocessed_cached_zarr = rec_preprocessed.save(zarr_path=preprocess_folder_zarr,  format='zarr', **job_kwargs)
write_zarr_recording with n_jobs 40  chunk_size 30000
write_zarr_recording: 100%|██████████| 1957/1957 [03:37<00:00,  9.01it/s]
Skipping field contact_plane_axes: only 1D and 2D arrays can be serialized
In [33]:
print(rec_preprocessed_cached_zarr)
ZarrRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s

show some traces

In [9]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[9]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f003e94c880>

estimate noise

We need some estimation of the noise.

Very important: we must estimate the noise with return_scaled=False because detect_peaks() will work on the raw data (very often int16).
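The noise level estimation is robust to the spikes themselves; a minimal numpy sketch of the underlying idea, assuming a MAD-based (median absolute deviation) estimator scaled to match the standard deviation of Gaussian noise (the exact estimator inside get_noise_levels() may differ):

```python
import numpy as np

# Minimal sketch of a robust per-channel noise estimate (assumption:
# MAD scaled to equal the std for Gaussian noise). Toy data only.
rng = np.random.default_rng(0)
traces = rng.normal(0.0, 5.0, size=(30000, 4))  # (samples, channels)

med = np.median(traces, axis=0)
mad = np.median(np.abs(traces - med), axis=0)
noise_levels = mad / 0.6744897501960817  # MAD -> std for a Gaussian
```

Because the median is barely affected by rare large events, this estimate stays close to the true noise level even when spikes are present in the traces.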

In [39]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
Out[39]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

We have 2 methods in spikeinterface, both implemented with numba:

  • 'by_channel' : peaks are detected on each channel independently
  • 'locally_exclusive' : if a unit fires on several channels, only the best peak on the best channel is kept. This is controlled by local_radius_um.
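As an illustration, the 'by_channel' idea boils down to thresholding each channel against a multiple of its noise level. A toy numpy sketch (not the actual numba implementation, which also enforces a local minimum over n_shifts samples):

```python
import numpy as np

# Toy sketch of 'by_channel'-style detection: negative threshold
# crossings per channel, with the threshold in units of noise level.
rng = np.random.default_rng(42)
traces = rng.normal(0.0, 1.0, size=(1000, 2))
traces[500, 0] = -12.0  # inject one artificial spike on channel 0

noise_levels = np.median(np.abs(traces), axis=0) / 0.6745
threshold = 5 * noise_levels  # detect_threshold=5
peak_mask = traces < -threshold[None, :]
sample_inds, chan_inds = np.nonzero(peak_mask)
```

'locally_exclusive' adds one more step: among neighboring channels (within local_radius_um), only the channel with the largest peak keeps the event.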
In [34]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [40]:
peaks = detect_peaks(rec_preprocessed,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [02:09<00:00, 15.09it/s]
(2531770,)

compare compute time with cached version

When we detect peaks on the lazy object, every trace chunk is loaded, processed, and then peaks are detected on it.

When we detect peaks on a cached version, each trace chunk is read directly from the saved version.

In [41]:
peaks = detect_peaks(rec_preprocessed_cached_binary,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:30<00:00, 21.55it/s]
(2528737,)
In [42]:
peaks = detect_peaks(rec_preprocessed_cached_zarr,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:28<00:00, 22.23it/s]
(2528737,)

Conclusion

Running peak detection on the lazy vs the cached version is an important choice.

detect_peaks() is a bit faster on the cached version (1:30) than on the lazy version (2:00).

But the total time of save() + detect_peaks() is slower (3:00 + 1:30 = 4:30)!!!

Here, writing to disk is clearly a waste of time.

So the benefit of caching totally depends on:

  1. the complexity of the preprocessing chain
  2. the disk-writing capability
  3. how many times the preprocessed recording will be consumed!!!

spikeinterface template matching

template matching in spikeinterface

Template matching is the final step used in many tools (kilosort, spyking-circus, yass, tridesclous, hdsort...)

In this step, given a catalogue (aka dictionary) of templates (aka atoms), algorithms explain the traces as a linear sum of templates plus residual noise.
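The "linear sum of templates plus residual noise" model can be illustrated with a toy example (all shapes, amplitudes, and spike times below are made up for the sketch):

```python
import numpy as np

# Toy model: traces = templates pasted at spike times + noise.
n_samples, n_channels, n_tmpl_samples = 300, 4, 30
templates = np.zeros((2, n_tmpl_samples, n_channels))
templates[0, 10:20, 0] = -50.0  # hypothetical unit 0, biggest on channel 0
templates[1, 5:25, 2] = -30.0   # hypothetical unit 1, biggest on channel 2

rng = np.random.default_rng(0)
traces = rng.normal(0.0, 2.0, size=(n_samples, n_channels))
spikes = [(50, 0), (120, 1)]  # (start sample, template index)
for start, tmpl in spikes:
    traces[start:start + n_tmpl_samples, :] += templates[tmpl]

# A template matching algorithm aims to recover `spikes` so that the
# residual (traces minus the reconstruction) is only noise.
residual = traces.copy()
for start, tmpl in spikes:
    residual[start:start + n_tmpl_samples, :] -= templates[tmpl]
```

With the correct spike times and templates subtracted, nothing but the noise floor remains, which is exactly the stopping criterion most matching algorithms rely on.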

We have started to implement some template matching procedures in spikeinterface.

Here is a small demo and also some benchmarks to compare the performance of these algorithms.

For this we will use a dataset simulated with MEArec on a 32-channel Neuronexus-like probe. Then we will compute the true templates using the true sorting. These true templates will be fed to the different methods. Finally, we will apply the ground-truth comparison procedure to evaluate this step in isolation.

In [2]:
# %matplotlib widget
%matplotlib inline
In [5]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [27]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [18]:
base_folder = Path('/mnt/data/sam/DataSpikeSorting/mearec_template_matching')
mearec_file = base_folder / 'recordings_collision_15cells_Neuronexus-32_1800s.h5'
wf_folder = base_folder / 'Waveforms_recording_15cells_Neuronexus-32_1800s'
rec_folder = base_folder /'Preprocessed_recording_15cells_Neuronexus-32_1800s'

open and preprocess

In [19]:
# load the already-cached recording, or compute it
if rec_folder.exists():
    recording = si.load_extractor(rec_folder)
else:
    recording, gt_sorting = si.read_mearec(mearec_file)
    recording = si.bandpass_filter(recording, dtype='float32')
    recording = si.common_reference(recording)
    recording = recording.save(folder=rec_folder, n_jobs=20, chunk_size=30000, progress_bar=True)
write_binary_recording with n_jobs 20  chunk_size 30000
write_binary_recording: 100%|██████████| 1800/1800 [00:16<00:00, 106.25it/s]

construct true templates

In [20]:
_, gt_sorting = si.read_mearec(mearec_file)
recording = si.load_extractor(rec_folder)
In [21]:
we = si.extract_waveforms(recording, gt_sorting, wf_folder, load_if_exists=True,
                           ms_before=2.5, ms_after=3.5, max_spikes_per_unit=500,
                           n_jobs=20, chunk_size=30000, progress_bar=True)
print(we)
extract waveforms memmap: 100%|██████████| 1800/1800 [00:11<00:00, 151.50it/s]
WaveformExtractor: 32 channels - 15 units - 1 segments
  before:75 after:105 n_per_units:500
In [23]:
metrics = si.compute_quality_metrics(we, metric_names=['snr'], load_if_exists=True)
metrics
Out[23]:
snr
#0 42.573563
#1 23.475538
#2 11.677200
#3 8.544864
#4 61.134109
#5 49.281887
#6 31.793837
#7 36.275745
#8 12.932632
#9 39.769772
#10 8.230338
#11 14.968547
#12 12.002127
#13 12.905783
#14 20.285872

run several method of template matching

A single function, find_spikes_from_templates(), is used for all of them.

In [28]:
from spikeinterface.sortingcomponents.template_matching import find_spikes_from_templates
In [25]:
# Some methods need the noise level (for internal detection)
noise_levels = si.get_noise_levels(recording, return_scaled=False)
noise_levels
Out[25]:
array([3.9969404, 3.9896376, 3.8046541, 3.5555122, 3.3091464, 3.257736 ,
       3.6201818, 3.9503036, 4.079712 , 4.2103205, 3.8557687, 3.9278026,
       3.8464408, 3.651188 , 3.4105062, 3.2170172, 3.3981993, 3.7377162,
       3.9932737, 4.1710896, 4.2710056, 4.296086 , 3.7716963, 3.7748668,
       3.6391177, 3.4687228, 3.3020885, 3.3594728, 3.6073673, 3.8444421,
       4.0852304, 4.234068 ], dtype=float32)
In [29]:
## these methods support parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_size=30000,
    progress_bar=True
)
In [30]:
# let's build a dict for handling parameters
methods = {}
methods['naive'] =  ('naive', 
                    {'waveform_extractor' : we})
methods['tridesclous'] =  ('tridesclous',
                           {'waveform_extractor' : we,
                            'noise_levels' : noise_levels,
                            'num_closest' :3})
methods['circus'] =  ('circus',
                      {'waveform_extractor' : we,
                       'noise_levels' : noise_levels})
methods['circus-omp'] =  ('circus-omp',
                          {'waveform_extractor' : we,
                           'noise_levels' : noise_levels})


spikes_by_methods = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs, **job_kwargs)
    spikes_by_methods[name] = spikes
find spikes (naive): 100%|██████████| 1800/1800 [00:05<00:00, 314.76it/s]
find spikes (tridesclous): 100%|██████████| 1800/1800 [00:06<00:00, 268.06it/s]
[1] compute overlaps: 100%|██████████| 180/180 [00:00<00:00, 1006.84it/s]
[2] compute amplitudes: 100%|██████████| 15/15 [00:01<00:00,  7.52it/s]
find spikes (circus): 100%|██████████| 1800/1800 [00:05<00:00, 342.75it/s]
find spikes (circus-omp): 100%|██████████| 1800/1800 [00:28<00:00, 62.85it/s]
In [34]:
## the output of every method is a numpy array with a structured dtype

spikes = spikes_by_methods['tridesclous']
print(spikes.dtype)
print(spikes.shape)
print(spikes[:5])
[('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]
(234977,)
[( 59,  0,  4, 1., 0) (309, 21,  8, 1., 0) (371, 13,  3, 1., 0)
 (623, 30, 14, 1., 0) (713, 31, 13, 1., 0)]

check performances method by method

For this:

  1. we transform the spikes vector into a sorting object
  2. use the compare_sorter_to_ground_truth() function to compute performances
  3. plot agreement matrix
  4. plot accuracy vs snr
  5. plot collision vs similarity

Note:

  • as we provide the true template list, every agreement matrix is supposed to be square!!! The performance can be seen on the diagonal. A perfect matching is supposed to have only ones on the diagonal.
  • The dataset here is one of the datasets used in the collision paper. We can also make a finer benchmark by inspecting collisions.
In [33]:
# load metrics for snr on true template
metrics = we.load_extension('quality_metrics').get_metrics()
In [35]:
templates = we.get_all_templates()

comparisons = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = spikes_by_methods[name]

    sorting = si.NumpySorting.from_times_labels(spikes['sample_ind'], spikes['cluster_ind'], recording.get_sampling_frequency())
    print(sorting)

    comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting)
    

    fig, axs = plt.subplots(ncols=2)
    si.plot_agreement_matrix(comp, ax=axs[0])
    si.plot_sorting_performance(comp, metrics, performance_name='accuracy', metric_name='snr', ax=axs[1], color='g')
    si.plot_sorting_performance(comp, metrics, performance_name='recall', metric_name='snr', ax=axs[1], color='b')
    si.plot_sorting_performance(comp, metrics, performance_name='precision', metric_name='snr', ax=axs[1], color='r')
    axs[0].set_title(name)
    axs[1].set_ylim(0.8, 1.1)
    
    comp = si.CollisionGTComparison(gt_sorting, sorting)
    comparisons[name] = comp
    fig, ax = plt.subplots()
    si.plot_comparison_collision_by_similarity(comp, templates, figure=fig)
    fig.suptitle(name)

plt.show()
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz

comparison of methods 2 by 2

In [36]:
names = list(comparisons.keys())
n = len(names)

for r, name0 in enumerate(names):
    for c, name1 in enumerate(names):
        if r<=c:
            continue

        fig, ax = plt.subplots()
        val0 = comparisons[name0].get_performance()['accuracy']
        val1 = comparisons[name1].get_performance()['accuracy']
        ax.scatter(val0, val1)
        ax.set_xlabel(name0)
        ax.set_ylabel(name1)
        ax.plot([0,1], [0, 1], color='k')
        ax.set_title('accuracy')
        ax.set_xlim(0.6, 1)
        ax.set_ylim(0.6, 1)

conclusion

  • tridesclous and circus-omp are the clear winners in terms of performance
  • tridesclous is the fastest
  • Improvements are still needed because the performance is far from perfect!!!

spikeinterface destripe

destripe processing in spikeinterface

Author : Samuel Garcia

Olivier Winter has developed for IBL a standard pre-processing chain in ibllib to clean the traces before spike sorting. See this.

This procedure is called "destripe". It removes artefacts that are present on all channels (common noise).

The main idea is to have this:

  1. filter
  2. align samples (phase shift)
  3. remove common noise
  4. apply spatial filter and bad channel interpolation

Except for step 4, all other steps are available in spikeinterface.

spikeinterface.toolkit.preprocessing proposes classes and functions to build what we call a lazy chain of processing.

Here is an example with 4 files kindly provided by Olivier Winter to illustrate the spikeinterface implementation of this destripe procedure.

In [42]:
# %matplotlib widget
%matplotlib inline
In [19]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [20]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [21]:
base_folder = Path('/media/samuel/dataspikesorting/DataSpikeSorting/olivier_destripe/')

folder1 = base_folder / '4c04120d-523a-4795-ba8f-49dbb8d9f63a'
folder2 = base_folder / '68f06c5f-8566-4a4f-a4b1-ab8398724913'
folder3 = base_folder / '8413c5c6-b42b-4ec6-b751-881a54413628'
folder4 = base_folder / 'f74a6b9a-b8a5-4c80-9c30-7dd4cdbb48c0'
data_folders = [folder1, folder2, folder3, folder4]

Build the preprocessing chain

In spikeinterface we have:

  • bandpass_filter()
  • common_reference(): removes common noise (global or local) by subtracting the median (or average)
  • phase_shift(): compensates for the sampling shift across channels (due to the ADC) by applying the inverse shift in the Fourier domain

These can be combined to get more or less the same result as the "destripe".
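The idea behind phase_shift(), assuming fractional-sample alignment via the Fourier shift theorem (a sketch, not the actual implementation), looks like this:

```python
import numpy as np

# Sketch of a fractional-sample delay using the Fourier shift theorem.
# Each ADC group on the probe samples with a small, known time offset;
# multiplying the spectrum by exp(-2j*pi*f*shift) re-aligns the channels.
def fractional_shift(sig, shift_samples):
    freqs = np.fft.rfftfreq(sig.size)  # in cycles per sample
    spectrum = np.fft.rfft(sig)
    spectrum *= np.exp(-2j * np.pi * freqs * shift_samples)
    return np.fft.irfft(spectrum, n=sig.size)

t = np.arange(64)
sig = np.sin(2 * np.pi * 5 * t / 64)
shifted = fractional_shift(sig, 1.0)  # whole-sample shift, easy to check
```

With an integer shift this reduces to a circular roll of the signal; the interesting case is a non-integer shift_samples, which no sample-domain operation can do exactly.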

Here we will compare 2 preprocessing:

  1. filter > cmr
  2. filter > phase_shift > cmr

Step 4 (kfilter) is not implemented yet, but this should be done soon.

In [38]:
# lets have a function that build the chain and plot intermediate results

def preprocess_steps(rec, time_range=None, clim=(-80, 80), figsize=(15, 10)):
    
    # chain 1. : filter + cmr
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000., dtype='float32')
    rec_cmr = si.common_reference(rec_filtered, reference='global', operator='median')
    
    # chain 2. : filter + phase_shift + cmr
    rec_pshift = si.phase_shift(rec_filtered)
    rec_cmr2 = si.common_reference(rec_pshift, reference='global', operator='median')
    
    
    # rec
    fig, axs = plt.subplots(ncols=5, sharex=True, sharey=True, figsize=figsize)
    
    ax = axs[0]
    ax.set_title('raw')
    si.plot_timeseries(rec, ax=ax,  with_colorbar=False) # clim=clim,
    
    # filter
    
    ax = axs[1]
    ax.set_title('filtered')
    si.plot_timeseries(rec_filtered, ax=axs[1], clim=clim, with_colorbar=False)
    
    # filter + cmr
    
    # rec_preprocessed
    ax = axs[2]
    ax.set_title('filtered + cmr')
    si.plot_timeseries(rec_cmr, ax=axs[2], clim=clim, with_colorbar=False)
    
    # filter + phase_shift
    
    ax = axs[3]
    ax.set_title('filtered + phase_shift')
    si.plot_timeseries(rec_pshift, ax=ax, clim=clim, with_colorbar=False)
    
    # filtered + phase_shift + cmr
    
    ax = axs[4]
    ax.set_title('filtered + phase_shift + cmr')
    si.plot_timeseries(rec_cmr2, ax=ax, clim=clim, with_colorbar=True)

    # optionally a time range can be given
    if time_range is not None:
        ax.set_xlim(*time_range)

dataset 1

In [30]:
rec = si.read_cbin_ibl(folder1)
preprocess_steps(rec)
In [31]:
# zoom on a stripe
preprocess_steps(rec, time_range=(0.95, 0.97))

dataset 2

In [32]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec)
In [33]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec, time_range=(0.2, .3))

dataset3

In [34]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec)
In [35]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec, time_range=(0.797, .801))

dataset 4

In [39]:
rec = si.read_cbin_ibl(folder4)
preprocess_steps(rec, clim=(-50, 50))
In [41]:
preprocess_steps(rec, clim=(-50, 50), time_range=(0.368, .375))

conclusion

Here we demonstrated the modular way of building a preprocessing chain directly in spikeinterface. This is particularly useful because:

  1. the same preprocessing can be applied for different sorters
  2. the preprocessing can be cached in parallel using rec.save(...) in binary or zarr format
  3. every step can be parameterized depending on the input dataset and the compute resources available.

spikeinterface motion estimation

motion estimation in spikeinterface

In 2021, the SpikeInterface project started to implement sortingcomponents, a module of modular spike sorting steps.

Here is an overview of our progress in integrating motion (aka drift) estimation and correction.

This notebook will be based on the open dataset from Nick Steinmetz published in 2021 "Imposed motion datasets" from Steinmetz et al. Science 2021 https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495

The motion estimation is done in several modular steps:

  1. detect peaks
  2. localize peaks
  3. estimate motion:
    • rigid or non rigid
    • "decentralize" by Erdem Varol and Julien Boussard DOI : 10.1109/ICASSP39728.2021.9414145
    • "motion cloud" by Julien Boussard (not implemented yet)

Here we will show this chain:

  • detect peaks > localize peaks with "monopolar_triangulation" > estimate motion with "decentralize"
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
from pathlib import Path

import spikeinterface.full as si

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 12)

from probeinterface.plotting import plot_probe


from spikeinterface.sortingcomponents import detect_peaks
from spikeinterface.sortingcomponents import localize_peaks
In [3]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'

peak_folder.mkdir(exist_ok=True)
In [4]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_memory='10M',
    progress_bar=True,
)
In [5]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[5]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
In [11]:
fig, ax = plt.subplots()
plot_probe(rec.get_probe(), ax=ax)
ax.set_ylim(-150, 200)
Out[11]:
(-150.0, 200.0)

preprocess

This takes 4 min for ~30 min of signal.

In [7]:
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
write_binary_recording with n_jobs 40  chunk_size 13020
write_binary_recording: 100%|██████████| 4510/4510 [03:25<00:00, 21.96it/s]
Out[7]:
BinaryRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
  file_paths: ['/mnt/data/sam/DataSpikeSorting/imposed_motion_nick/dataset1_NP1_preprocessed/traces_cached_seg0.raw']
In [5]:
# load back
rec_preprocessed = si.load_extractor(preprocess_folder)
rec_preprocessed
Out[5]:
BinaryRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
  file_paths: ['/mnt/data/sam/DataSpikeSorting/imposed_motion_nick/dataset1_NP1_preprocessed/traces_cached_seg0.raw']
In [12]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[12]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7fc95972ae50>

estimate noise

In [14]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=np.arange(0,10, 1))
ax.set_title('noise across channel')
Out[14]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

This takes 1 min 30 s.

In [15]:
from spikeinterface.sortingcomponents import detect_peaks
In [16]:
peaks = detect_peaks(
    rec_preprocessed,
    method='locally_exclusive',
    local_radius_um=100,
    peak_sign='neg',
    detect_threshold=5,
    n_shifts=5,
    noise_levels=noise_levels,
    **job_kwargs,
)
np.save(peak_folder / 'peaks.npy', peaks)
detect peaks: 100%|██████████| 4510/4510 [01:31<00:00, 49.13it/s]
In [8]:
# load back
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
(4041217,)

localize peaks

We use 2 methods:

  • 'center_of_mass': 9 s
  • 'monopolar_triangulation': 26 min
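'center_of_mass' is essentially a weighted average of channel positions; a toy sketch (channel positions and amplitudes below are made up for illustration):

```python
import numpy as np

# Toy 'center_of_mass' localization: channel positions weighted by the
# absolute peak amplitude seen on each channel near the event.
channel_positions = np.array([[0., 0.], [0., 20.], [0., 40.]])  # (x, y) in um
peak_amplitudes = np.array([1., 3., 1.])

com = (channel_positions * peak_amplitudes[:, None]).sum(axis=0) \
      / peak_amplitudes.sum()
# the estimated (x, y) lies between the channels, pulled toward
# the channel with the largest amplitude
```

'monopolar_triangulation' instead fits a point-source model of amplitude decay over the channels, which is far more expensive but also estimates the distance to the probe plane.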
In [18]:
from spikeinterface.sortingcomponents import localize_peaks
In [19]:
peak_locations = localize_peaks(
    rec_preprocessed,
    peaks,
    ms_before=0.3,
    ms_after=0.6,
    method='center_of_mass',
    method_kwargs={'local_radius_um': 100.},
    **job_kwargs,
)
np.save(peak_folder / 'peak_locations_center_of_mass.npy', peak_locations)
print(peak_locations.shape)
localize peaks: 100%|██████████| 4510/4510 [00:09<00:00, 461.01it/s]
(4041217,)
In [20]:
peak_locations = localize_peaks(
    rec_preprocessed,
    peaks,
    ms_before=0.3,
    ms_after=0.6,
    method='monopolar_triangulation',
    method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000.},
    **job_kwargs,
)
np.save(peak_folder / 'peak_locations_monopolar_triangulation.npy', peak_locations)
print(peak_locations.shape)
localize peaks:   0%|          | 2/4510 [00:13<10:43:51,  8.57s/it]
In [6]:
# load back
# peak_locations = np.load(peak_folder / 'peak_locations_center_of_mass.npy')
peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation.npy')
print(peak_locations)
[(  18.52504101, 1783.26060082,  80.56493564, 1736.54517744)
 (  75.90387896, 4135.11490531,   1.02883473, 4001.33816608)
 ( -23.97108877, 2632.738146  ,  87.2656153 , 2632.17702833) ...
 (  40.06415842, 1977.85847864,  26.4586952 , 1091.46159133)
 (-185.47200933, 1795.53548018, 155.37976473, 3492.17984483)
 (  58.83825019, 1178.6461218 ,  82.17022322, 1253.97375113)]

plot peaks on the probe

In [16]:
probe = rec_preprocessed.get_probe()

fig, ax = plt.subplots(figsize=(15, 10))
plot_probe(probe, ax=ax)
ax.scatter(peak_locations['x'], peak_locations['y'], color='k', s=1, alpha=0.002)
# ax.set_ylim(2400, 2900)
ax.set_ylim(1500, 2500)
Out[16]:
(1500.0, 2500.0)

plot peak depth vs time

In [11]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
Out[11]:
(1300.0, 2500.0)

motion estimate : rigid with decentralized

In [17]:
from spikeinterface.sortingcomponents import (
    estimate_motion,
    make_motion_histogram,
    compute_pairwise_displacement,
    compute_global_displacement
)
In [18]:
bin_um = 2
bin_duration_s=5.

motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations, 
    bin_um=bin_um,
    bin_duration_s=bin_duration_s,
    direction='y',
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)
(392, 1960) 393 1961
In [22]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])
im = ax.imshow(
    motion_histogram.T,
    interpolation='nearest',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(0, 15)
ax.set_ylim(1300, 2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
Out[22]:
Text(0, 0.5, 'depth[um]')

pairwise displacement from the motion histogram

In [23]:
pairwise_displacement = compute_pairwise_displacement(motion_histogram, bin_um, method='conv2d', )
np.save(peak_folder / 'pairwise_displacement_conv2d.npy', pairwise_displacement)
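Conceptually (assuming the 'conv2d' method correlates pairs of temporal bins of the motion histogram), the displacement between two depth profiles can be read off the argmax of their cross-correlation; a toy sketch:

```python
import numpy as np

# Sketch (assumption): the displacement between two depth profiles is
# estimated from the lag that maximizes their cross-correlation.
bin_um = 2
depth = np.arange(200)
profile_a = np.exp(-0.5 * ((depth - 100) / 5.0) ** 2)  # activity at time t0
profile_b = np.roll(profile_a, 7)                      # same activity, drifted 7 bins

xcorr = np.correlate(profile_b, profile_a, mode='full')
lag_bins = int(np.argmax(xcorr)) - (profile_a.size - 1)
displacement_um = lag_bins * bin_um
```

Doing this for every pair of temporal bins fills the antisymmetric pairwise displacement matrix plotted below.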
In [24]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], temporal_bins[0], temporal_bins[-1])
# extent = None
im = ax.imshow(
    pairwise_displacement,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
Out[24]:
<matplotlib.colorbar.Colorbar at 0x7f48ee351eb0>

estimate motion (rigid) from the pairwise displacement

In [25]:
motion = compute_global_displacement(pairwise_displacement)
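Conceptually, in the rigid case, compute_global_displacement() solves for a motion vector consistent with all pairwise measurements, D[i, j] ≈ motion[i] - motion[j]. In a noiseless toy case a simple row average already gives the least-squares solution (the real solver is more robust than this sketch):

```python
import numpy as np

# Sketch: recover a rigid motion vector from a pairwise displacement
# matrix D with D[i, j] = motion[i] - motion[j] (noiseless toy case).
rng = np.random.default_rng(0)
true_motion = np.cumsum(rng.normal(0.0, 1.0, 50))  # a random drift trace
D = true_motion[:, None] - true_motion[None, :]

motion = D.mean(axis=1)   # least-squares solution up to a global offset
motion -= motion[0]       # the offset is arbitrary; anchor at t=0
```

The global offset is unobservable from pairwise differences alone, which is why the estimated motion is only defined up to a constant.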
In [26]:
fig, ax = plt.subplots()
ax.plot(temporal_bins[:-1], motion)
Out[26]:
[<matplotlib.lines.Line2D at 0x7f48f6a9e5e0>]

motion estimation with one unique function

Internally estimate_motion() does:

  • make_motion_histogram()
  • compute_pairwise_displacement()
  • compute_global_displacement()
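The first of these steps amounts to a 2D histogram of peak times vs peak depths; a rough sketch with random toy data (the real function also supports amplitude weighting):

```python
import numpy as np

# Toy sketch of a motion histogram: bin the detected peaks in
# (time, depth), as make_motion_histogram() does internally.
rng = np.random.default_rng(0)
peak_times = rng.uniform(0, 100, size=5000)   # seconds (toy data)
peak_depths = rng.uniform(0, 400, size=5000)  # um along the probe (toy data)

bin_duration_s, bin_um = 5.0, 2.0
temporal_bins = np.arange(0, 100 + bin_duration_s, bin_duration_s)
spatial_bins = np.arange(0, 400 + bin_um, bin_um)
motion_histogram, _, _ = np.histogram2d(
    peak_times, peak_depths, bins=[temporal_bins, spatial_bins])
```

Each row of the histogram is a depth profile of activity for one temporal bin; drift shows up as these profiles sliding along the depth axis over time.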
In [27]:
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=10.,
    method='decentralized_registration',
    method_kwargs={},
    non_rigid_kwargs=None, 
    progress_bar=True,
    verbose=True,
)
make_motion_histogram
0
compute_pairwise_displacement 0
100%|██████████| 392/392 [00:06<00:00, 63.11it/s]
compute_global_displacement 0
In [30]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)


ax.plot(temporal_bins[:-1], motion + 2000, color='r')
ax.set_xlabel('times[s]')
ax.set_ylabel('motion [um]')
Out[30]:
Text(0, 0.5, 'motion [um]')

non-rigid motion estimation

In [31]:
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=10.,
    method='decentralized_registration',
    method_kwargs={},
    non_rigid_kwargs=dict(bin_step_um=200),
    progress_bar=True,
    verbose=True,
)
print(motion.shape)
print(temporal_bins.shape)
make_motion_histogram
0
compute_pairwise_displacement 0
100%|██████████| 392/392 [00:06<00:00, 62.04it/s]
compute_global_displacement 0
...
19
compute_pairwise_displacement 19
100%|██████████| 392/392 [00:06<00:00, 64.09it/s]
compute_global_displacement 19
(392, 20)
(393,)
In [32]:
fs = rec_preprocessed.get_sampling_frequency()

fig, ax = plt.subplots()
ax.scatter(peaks['sample_ind'] / fs, peak_locations['y'], color='k', s=0.1, alpha=0.05)


for i, s_bin in enumerate(spatial_bins):
    # one motion vector per spatial bin
    ax.plot(temporal_bins[:-1], motion[:, i] + s_bin, color='r')

ax.set_ylim(1300, 2500)
ax.set_xlabel('times[s]')
ax.set_ylabel('motion [um]')
Out[32]:
Text(0, 0.5, 'motion [um]')

Collision paper spike sorting performance

Spike sorting performance against spike collisions (figure 2-3-5)

In this notebook, we describe how to generate the figures for all the studies, i.e. for all rate and correlation levels, in a systematic manner. While by default all the figures were saved as .pdf, here we restrict the ranges of rates and correlations to display only a single figure. Feel free to modify the scripts in order to regenerate all the figures.

In [1]:
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import gridspec

import MEArec as mr
import spikeinterface.full as si
In [2]:
study_base_folder = Path('../data/study/')

Plot global spike sorting performance (Figure 2)

In [1]:
res = {}

rate_levels = [5]
corr_levels = [0]

for rate_level in rate_levels:
    for corr_level in corr_levels:

        fig = plt.figure(figsize=(15,5))
        gs = gridspec.GridSpec(2, 3, figure=fig)

        study_folder = study_base_folder / f'20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        
        study = si.GroundTruthStudy(study_folder)
        study.run_comparisons(exhaustive_gt=True)

        ax_1 = plt.subplot(gs[0, 0])
        ax_2 = plt.subplot(gs[0, 1:])
        ax_3 = plt.subplot(gs[1, 1:])
        ax_4 = plt.subplot(gs[1, 0])

        for ax in [ax_1, ax_2, ax_3, ax_4]:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

        ax_2.tick_params(labelbottom=False)
        ax_2.set_xlabel('')

        si.plot_gt_study_run_times(study, ax=ax_1)
        si.plot_gt_study_unit_counts(study, ax=ax_2)
        si.plot_gt_study_performances_averages(study, ax=ax_3)
        si.plot_gt_study_performances_by_template_similarity(study, ax=ax_4)

        plt.tight_layout()

Plot collision recall as function of the lags (Figure 3)

In [2]:
for rate_level in rate_levels:
    for corr_level in corr_levels:
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)

        for rec_name in res[(rate_level, corr_level)].rec_names:
            res[(rate_level, corr_level)].compute_waveforms(rec_name)

        si.plot_study_comparison_collision_by_similarity(res[(rate_level, corr_level)], 
                                                         show_legend=False, ylim=(0.4, 1))
        plt.tight_layout()
Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
[GCC 8.4.0] does not support parallel processing

Plot collision recall as function of the lag and/or cosine similarity (supplementary figures)

In [3]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_range(res[(rate_level, corr_level)], show_legend=show_legend, similarity_range=[0.5, 1], ax=ax, ylim=(0.3, 1))

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('lags (ms)')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')
In [4]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_ranges(res[(rate_level, corr_level)], show_legend=show_legend, ax=ax)

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('similarity')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')

Plot average collision recall over multiple conditions, as function of the lags (Figure 5)

In [9]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]

gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:
            data = res[(rate_level, corr_level)].get_mean_over_similarity_range([0.5, 1], sorter_name)
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

lags = res[(rate_level, corr_level)].get_lags()
for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter, label=sorter_name)
    ax.fill_between(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)

ax.legend()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('lag (ms)')
ax.set_ylabel('collision accuracy')
Out[9]:
Text(0, 0.5, 'collision accuracy')

Plotting the average collision recall over multiple conditions, as function of the similarity

In [5]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]
res = {}
gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
similarity_ranges = np.linspace(-0.4, 1, 8)
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:

            all_similarities = res[(rate_level, corr_level)].all_similarities[sorter_name]
            all_recall_scores = res[(rate_level, corr_level)].all_recall_scores[sorter_name]

            order = np.argsort(all_similarities)
            all_similarities = all_similarities[order]
            all_recall_scores = all_recall_scores[order, :]

            mean_recall_scores = []
            std_recall_scores = []
            for k in range(similarity_ranges.size - 1):
                cmin, cmax = similarity_ranges[k], similarity_ranges[k + 1]
                amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
                value = np.mean(all_recall_scores[amin:amax])
                mean_recall_scores += [np.nan_to_num(value)]

            xaxis = np.diff(similarity_ranges)/2 + similarity_ranges[:-1]

            data = mean_recall_scores
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(xaxis, mean_sorter, label=sorter_name)
    ax.fill_between(xaxis, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)


ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('cosine similarity')
#ax.set_ylabel('collision accuracy')
#ax.set_yticks([])

plt.tight_layout()
/home/cure/.local/lib/python3.6/site-packages/numpy/core/fromnumeric.py:3373: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/cure/.local/lib/python3.6/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)

Collision paper simulated recordings

Simulated recordings overview (figure 1)

This notebook reproduces Figure 1 of the manuscript: "How do spike collisions affect spike sorting performance?"

To run this notebook, you first need to run the generate_recordings.ipynb notebook.

In [1]:
import shutil
import sys
from pathlib import Path

import numpy as np
import scipy.spatial

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import MEArec as mr
import spikeinterface.full as si


my_cmap = plt.get_cmap('winter')
cNorm  = colors.Normalize(vmin=0, vmax=1)
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
In [ ]:
sys.path.append("../utils")

from generation_utils import generation_params
from study_utils import generate_study
In [ ]:
recordings_folder = Path('../data/recordings/')
In [2]:
# define some parameters

nb_traces = 10 # for panel I
window_ms = 20 #for CC plots
bin_ms = 0.2 # for CC plots
n_cell = 20 #
lag_time = generation_params['lag_time']*1000
corr_level = 0 # to select the appropriate recording if several (run generation first)
rate_level = 5 # to select the appropriate recording if several (run generation first)
In [8]:
# We use the plotting.py script to ease the creation of figures with several panels. 
figA, axA = plt.subplots()

# We load the file
rec_file = recordings_folder / f'rec0_20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32.h5'

mearec_object = mr.load_recordings(rec_file)
rec = si.MEArecRecordingExtractor(rec_file)
sorting_gt = si.MEArecSortingExtractor(rec_file)

waveforms_path = Path('.') / 'tmp'
waveforms_path.mkdir(exist_ok=True)

waveforms = si.extract_waveforms(rec, sorting_gt, waveforms_path, ms_before=3, ms_after=3)

original_templates = waveforms.get_all_templates()
snrs = np.array([i for i in si.compute_snrs(waveforms).values()])
rates = np.array([i for i in si.compute_firing_rate(waveforms).values()])


## Plotting the probe layout and the cell positions
si.plot_unit_localization(waveforms, ax=axA)
axA.set_ylabel('y (um)')
axA.set_xlabel('x (um)')
In [9]:
figB, axB = plt.subplots(ncols=3, figsize=(12, 7))

colors = {'#0' : 'k', '#16' : 'r'}

similarities = si.compute_template_similarity(waveforms)

## Plotting example of pair with selected similarity
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#16'], unit_colors=colors)
axB[0].set_title('(#0, #16) similarity %.2g' % similarities[0, 16])

colors = {'#0' : 'k', '#10' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#10'], unit_colors=colors)    
axB[1].set_title('(#0, #10) similarity %.2g' % similarities[0, 10])

colors = {'#0' : 'k', '#1' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#0'], unit_colors=colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#1'], unit_colors=colors)    
axB[2].set_title('(#0, #1) similarity %.2g' % similarities[0, 1])
figB.tight_layout()
In [11]:
figC, axC = plt.subplots()

## Plotting the similarity matrix
im = axC.imshow(similarities, cmap='viridis',
                aspect='auto',
                origin='lower',
                interpolation='none',
                extent=(-0.5, n_cell-1+0.5, -0.5, n_cell-1+0.5))
axC.set_xlabel('# cell')
axC.set_ylabel('# cell')
plt.colorbar(im, ax=axC, label='cosine similarity')
Out[11]:
<matplotlib.colorbar.Colorbar at 0x7f6c27a39208>
In [12]:
figDE, axDE = plt.subplots(nrows=2)

centers = np.array([v for v in si.compute_unit_centers_of_mass(waveforms).values()])
real_centers = mearec_object.template_locations[:]

distances = scipy.spatial.distance_matrix(centers, centers)
real_distances =  scipy.spatial.distance_matrix(real_centers, real_centers)

# Plotting the distribution of similarities as function of distance (either real or estimated)
axDE[0].plot(distances.flatten(), similarities.flatten(), '.', label='Center of Mass')
axDE[0].plot(real_distances.flatten(), similarities.flatten(), '.', label='Real position')
axDE[0].legend()
axDE[0].set_xlabel('distances (um)')
axDE[0].set_ylabel('cosine similarity')

x, y = np.histogram(similarities.flatten(), 10)
axDE[1].bar(y[1:], x/float(x.sum()), width=y[1]-y[0])
axDE[1].set_xlabel('cosine similarity')
axDE[1].set_ylabel('probability')
Out[12]:
Text(0, 0.5, 'probability')
In [14]:
## Plotting a few example cross-correlograms (the final figure panel was assembled separately)
w = si.plot_crosscorrelograms(sorting_gt, ['#%s' %i for i in range(0,3)], 
                              bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
figF = w.figure
Out[14]:
<spikeinterface.widgets.correlograms.CrossCorrelogramsWidget at 0x7f6bae84ae80>
In [16]:
figGH, axGH = plt.subplots(nrows=2)

ccs, lags = si.compute_correlograms(sorting_gt, bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
ccs = ccs.reshape(n_cell**2, ccs.shape[2])
mask = np.ones(n_cell**2).astype(bool)
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = False
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

## Plotting the average CC
xaxis = (lags[:-1] - lags[:-1].mean())
axGH[0].plot(xaxis, mean_cc, lw=2, c='r')
axGH[0].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[0].set_xlabel('time (ms)')
axGH[0].set_ylabel('cross correlation')
ymin, ymax = axGH[0].get_ylim()
axGH[0].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[0].plot([lag_time,lag_time],[ymin,ymax],'k--')

mask = np.zeros(n_cell**2).astype(bool)
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = True
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

xaxis = (lags[:-1] - lags[:-1].mean())
axGH[1].plot(xaxis, mean_cc, lw=2, c='r')
axGH[1].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[1].set_ylabel('auto correlation')
ymin, ymax = axGH[1].get_ylim()
axGH[1].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[1].plot([lag_time,lag_time],[ymin,ymax],'k--')
Out[16]:
[<matplotlib.lines.Line2D at 0x7f6c702a1358>]
In [18]:
## Plotting timeseries
w = si.plot_timeseries(rec, time_range=(5,5.1), channel_ids=['%s' %i for i in range(1,nb_traces)], color='k')
figI = w.figure
Out[18]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f6badf61f60>

Collision paper generate recordings

Generation of the recordings

In this notebook, we will generate all the recordings with MEArec that are necessary to populate the study and compare the sorters. First, we need to create a function that, given a dictionary of parameters, generates a single recording. The recording parameters can be defined as follows:

In [4]:
import os
import sys
import shutil
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import neo
import quantities as pq

import MEArec as mr
import spikeinterface.full as si
In [8]:
sys.path.append('../utils/')

from corr_spike_trains import CorrelatedSpikeGenerator
In [3]:
generation_params = {
    'probe' : 'Neuronexus-32', #layout of the probe used
    'duration' : 30*60, #total duration of the recording
    'n_cell' : 20, # number of cells that will be injected
    'fs' : 30000., # sampling rate
    'lag_time' : 0.002,  # half refractory period in s (i.e. 2 ms)
    'make_plots' : True,
    'generate_recording' : True,
    'noise_level' : 5,
    'templates_seed' : 42,
    'noise_seed' : 42,
    'global_path' : os.path.abspath('../'),
    'study_number' : 0,
    'save_plots' : True,
    'method' : 'brette', # 'poisson' | 'brette'
    'corr_level' : 0,
    'rate_level' : 5, #Hz
    'nb_recordings' : 5
}

With these parameters, we will create 20 neurons, and correlation levels will be generated via the mixture process of [Brette et al, 2009]. The function to generate a single recording is defined as follows. It assumes that you have, in your folder, a file named ../data/templates/templates_{probe}_100.h5 with all the pre-generated templates that will be used by MEArec
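The actual generation below relies on the CorrelatedSpikeGenerator helper from ../utils. As a rough, simplified stand-in for that mixture process, the mother-process construction below illustrates the underlying idea: every cell thins a common Poisson "mother" train with probability sqrt(c), which yields a pairwise count correlation of roughly c. All names here are hypothetical, for illustration only:

```python
import numpy as np

rng = np.random.default_rng(42)

def correlated_poisson_trains(n_cell, rate, corr, duration):
    """Toy mother-process sketch: shared mother spikes thinned per cell."""
    p = np.sqrt(corr) if corr > 0 else 0.0
    trains = []
    if p > 0:
        mother_rate = rate / p                      # so each thinned train keeps `rate`
        n_mother = rng.poisson(mother_rate * duration)
        mother = np.sort(rng.uniform(0, duration, n_mother))
    for _ in range(n_cell):
        if p > 0:
            keep = rng.random(mother.size) < p      # independent thinning per cell
            times = mother[keep]
        else:
            n = rng.poisson(rate * duration)        # independent Poisson trains
            times = np.sort(rng.uniform(0, duration, n))
        trains.append(times)
    return trains

trains = correlated_poisson_trains(n_cell=20, rate=5.0, corr=0.1, duration=60.0)
```

Correlated cells share mother spikes, so spike collisions become more frequent as the correlation level grows, which is exactly the regime the study probes.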

In [5]:
def generate_single_recording(params=generation_params):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings') 

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    template_filename = os.path.join(paths['templates'], f'templates_{probe}_100.h5')
    recording_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')
    plot_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.pdf')

    spikerate = params['rate_level']
    n_spike_alone = int(spikerate * params['duration'])

    print('Total target rate:', params['rate_level'], "Hz")
    print('Basal rate:', spikerate, "Hz")


    # collision lag range
    lag_sample = int(params['lag_time'] * params['fs'])

    refactory_period = 2 * params['lag_time']

    spiketimes = []

    if params['method'] == 'poisson':
        print('Spike trains generated as independent poisson sources')
        
        for i in range(params['n_cell']):
            
            #~ n = n_spike_alone + n_collision_by_pair * (params['n_cell'] - i - 1)
            n = n_spike_alone
            #~ times = np.random.rand(n_spike_alone) * params['duration']
            times = np.random.rand(n) * params['duration']
            
            times = np.sort(times)
            spiketimes.append(times)

    elif params['method'] == 'brette':
        print('Spike trains generated as compound mixtures')
        C = np.ones((params['n_cell'], params['n_cell']))
        C = params['corr_level'] * np.maximum(C, C.T)
        #np.fill_diagonal(C, 0*np.ones(params['n_cell']))

        rates = params['rate_level'] * np.ones(params['n_cell'])

        cor_spk = CorrelatedSpikeGenerator(C, rates, params['n_cell'])
        cor_spk.find_mixture(iter=1e4)
        res = cor_spk.mixture_process(tauc=refactory_period/2, t=params['duration'])
        
        # extract the spike times of each cell from the mixture output
        for i in range(params['n_cell']):
            #~ print(spiketimes[i])
            mask = res[:, 0] == i
            times = res[mask, 1]
            times = np.sort(times)
            mask = (times > 0) * (times < params['duration'])
            times = times[mask]
            spiketimes.append(times)


    # remove refractory period violations
    for i in range(params['n_cell']):
        times = spiketimes[i]
        ind, = np.nonzero(np.diff(times) < refactory_period)
        ind += 1
        times = np.delete(times, ind)
        assert np.sum(np.diff(times) < refactory_period) ==0
        spiketimes[i] = times

    # make neo spiketrains
    spiketrains = []
    for i in range(params['n_cell']):
        mask = np.where(spiketimes[i] > 0)
        spiketimes[i] = spiketimes[i][mask] 
        spiketrain = neo.SpikeTrain(spiketimes[i], units='s', t_start=0*pq.s, t_stop=params['duration']*pq.s)
        spiketrain.annotate(cell_type='E')
        spiketrains.append(spiketrain)

    # check with sanity plot here
    if params['make_plots']:
        
        # count number of spike per units
        fig, axs = plt.subplots(2, 2)
        count = [st.size for st in spiketrains]
        ax = axs[0, 0]
        simpleaxis(ax)
        pairs = []
        collision_count_by_pair = []
        collision_count_by_units = np.zeros(n_cell)
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                pairs.append(f'{i}-{j}')
                collision_count_by_pair.append(matching_event.size)
                collision_count_by_units[i] += matching_event.size
                collision_count_by_units[j] += matching_event.size
        ax.plot(np.arange(len(collision_count_by_pair)), collision_count_by_pair)
        ax.set_xticks(np.arange(len(collision_count_by_pair)))
        ax.set_xticklabels(pairs)
        ax.set_ylim(0, max(collision_count_by_pair) * 1.1)
        ax.set_ylabel('# Collisions')
        ax.set_xlabel('Pairs')

        # count number of spike per units
        count_total = np.array([st.size for st in spiketrains])
        count_not_collision = count_total - collision_count_by_units

        ax = axs[1, 0]
        simpleaxis(ax)
        ax.bar(np.arange(n_cell).astype(int) + 1, count_not_collision, color='g')
        ax.bar(np.arange(n_cell).astype(int) + 1, collision_count_by_units, bottom=count_not_collision, color='r')
        ax.set_ylabel('# spikes')
        ax.set_xlabel('Cell id')
        ax.legend(('Not colliding', 'Colliding'), loc='best')

        # cross correlogram
        ax = axs[0, 1]
        simpleaxis(ax)
        counts = []
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                
                #~ ax = axs[i, j]
                all_lag = matching_event['delta_frame']  / params['fs']
                count, bins  = np.histogram(all_lag, bins=np.arange(-params['lag_time'], params['lag_time'], params['lag_time']/20))
                #~ ax.bar(bins[:-1], count, bins[1] - bins[0])
                ax.plot(1000*bins[:-1], count, c='0.5')
                counts += [count]
        counts = np.array(counts)
        counts = np.mean(counts, 0)
        ax.plot(1000*bins[:-1], counts, c='r')
        ax.set_xlabel('Lags [ms]')
        ax.set_ylabel('# Collisions')

        ax = axs[1, 1]
        simpleaxis(ax)
        ratios = []
        for i in range(n_cell):
            nb_spikes = len(spiketrains[i])
            nb_collisions = 0
            times1 = spiketrains[i].rescale('s').magnitude
            for j in list(range(0, i)) + list(range(i+1, n_cell)):
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                nb_collisions += matching_event.size

            if nb_collisions > 0:
                ratios += [nb_spikes / nb_collisions]
            else:
                ratios += [0]

        ax.bar([0], [np.mean(ratios)], yerr=[np.std(ratios)])
        ax.set_ylabel('# spikes / # collisions')
        plt.tight_layout()

        if params['save_plots']:
            plt.savefig(plot_filename)
        else:
            plt.show()
        plt.close()

    if params['generate_recording']:
        spgen = mr.SpikeTrainGenerator(spiketrains=spiketrains)
        rec_params = mr.get_default_recordings_params()
        rec_params['recordings']['fs'] = params['fs']
        rec_params['recordings']['sync_rate'] = None
        rec_params['recordings']['sync_jitter'] = 5
        rec_params['recordings']['noise_level'] = params['noise_level']
        rec_params['recordings']['filter'] = False
        rec_params['spiketrains']['duration'] = params['duration']
        rec_params['spiketrains']['n_exc'] = params['n_cell']
        rec_params['spiketrains']['n_inh'] = 0
        rec_params['recordings']['chunk_duration'] = 10.
        rec_params['templates']['n_overlap_pairs'] = None
        rec_params['templates']['min_dist'] = 0
        rec_params['seeds']['templates'] = params['templates_seed']
        rec_params['seeds']['noise'] = params['noise_seed']
        recgen = mr.gen_recordings(params=rec_params, spgen=spgen, templates=template_filename, verbose=True)
        mr.save_recording_generator(recgen, filename=recording_filename)
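The refractory-period cleanup inside generate_single_recording deletes any spike that follows its predecessor by less than the refractory period, via np.diff and np.delete. A toy reproduction of that exact logic:

```python
import numpy as np

def enforce_refractory(times, refractory_period):
    """Delete spikes closer than refractory_period to the previous spike
    (same np.diff / np.delete trick as in generate_single_recording)."""
    times = np.sort(np.asarray(times, dtype=float))
    ind, = np.nonzero(np.diff(times) < refractory_period)
    times = np.delete(times, ind + 1)
    return times

clean = enforce_refractory([0.0, 0.001, 0.010, 0.0105, 0.020], 0.004)
# the spikes at 0.001 and 0.0105 violate the 4 ms period and are removed
```

Note that this is a single pass: the deletions are computed against the original neighbours, which is why the function asserts afterwards that no violation remains.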

Once this function is created, we can create an additional function that will generate several recordings, with different suffix/seeds:

In [6]:
def generate_recordings(params=generation_params):
    for i in range(params['nb_recordings']):
        generation_params['study_number'] = i
        generation_params['templates_seed'] = i
        generation_params['noise_seed'] = i
        generate_single_recording(generation_params)

And now we have all the required tools to create our recordings. By default, they will all be saved in the folder ../data/recordings/.

In [7]:
## Provide the different rate and correlations levels you want to generate
rate_levels = [5, 10, 15]
corr_levels = [0, 0.1, 0.2]
generation_params['nb_recordings'] = 5 #Number of recordings per conditions
In [ ]:
result = {}

for rate_level in rate_levels:
    for corr_level in corr_levels:

        generation_params['rate_level'] = rate_level
        generation_params['corr_level'] = corr_level
        generate_recordings(generation_params)

Generation of the study objects

Now that the recordings have been generated, we need to create Study objects for spikeinterface and run the sorters on all these recordings. Be aware that, by default, this can create quite a large amount of data if you have numerous rate/correlation levels and/or recordings and/or sorters. First, we need to tell spikeinterface how to find the sorters:

In [11]:
ironclust_path = '/media/cure/Secondary/pierre/softwares/ironclust'
kilosort1_path = '/media/cure/Secondary/pierre/softwares/Kilosort-1.0'
kilosort2_path = '/media/cure/Secondary/pierre/softwares/Kilosort-2.0'
kilosort3_path = '/media/cure/Secondary/pierre/softwares/Kilosort-3.0'
hdsort_path = '/media/cure/Secondary/pierre/softwares/HDsort'
os.environ["KILOSORT_PATH"] = kilosort1_path
os.environ["KILOSORT2_PATH"] = kilosort2_path
os.environ["KILOSORT3_PATH"] = kilosort3_path
os.environ['IRONCLUST_PATH'] = ironclust_path
os.environ['HDSORT_PATH'] = hdsort_path

And then we need to create a function that, given a list of recordings, creates a study and runs all the sorters:

In [13]:
def generate_study(params, keep_data=True):
    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')
    paths['study'] = os.path.join(paths['data'], 'study')
    
    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    paths['mearec_filename'] = []

    study_folder = os.path.join(paths['study'], f'{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}')
    study_folder = Path(study_folder)

    if params['reset_study'] and os.path.exists(study_folder):
        shutil.rmtree(study_folder)

    print('Availables sorters:')
    si.print_sorter_versions()

    gt_dict = {}

    if not os.path.exists(study_folder):

        for i in range(params['nb_recordings']):
            paths['mearec_filename'] += [os.path.join(paths['recordings'], f'rec{i}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')]

        print('Availables recordings:')
        print(paths['mearec_filename'])

        
        for count, file in enumerate(paths['mearec_filename']):
            rec  = si.MEArecRecordingExtractor(file)
            sorting_gt = si.MEArecSortingExtractor(file)
            gt_dict['rec%d' %count] = (rec, sorting_gt)

        study = si.GroundTruthStudy.create(study_folder, gt_dict, n_jobs=-1, chunk_memory='1G', progress_bar=True)
        study.run_sorters(params['sorter_list'], verbose=False, docker_images=params['docker_images'])
        print("Study created!")
    else:
        study = si.GroundTruthStudy(study_folder)
        if params['relaunch'] == 'all':
            if_exist = 'overwrite'
        elif params['relaunch'] == 'some':
            if_exist = 'keep'

        if params['relaunch'] in ['all', 'some']:
            study.run_sorters(params['sorter_list'], verbose=False, mode_if_folder_exists=if_exist, docker_images=params['docker_images'])
            print("Study loaded!")

    study.copy_sortings()

    if not keep_data:

        for sorter in params['sorter_list']:

            for rec in ['rec%d' %i for i in range(params['nb_recordings'])]:
                sorter_path = os.path.join(study_folder, 'sorter_folders', rec, sorter)
                if os.path.exists(sorter_path):
                    for f in os.listdir(sorter_path):
                        if f != 'spikeinterface_log.json':
                            full_file = os.path.join(sorter_path, f)
                            try:
                                if os.path.isdir(full_file):
                                    shutil.rmtree(full_file)
                                else:
                                    os.remove(full_file)
                            except Exception:
                                pass
        for file in paths['mearec_filename']:
            os.remove(file)

    return study

This function takes a dictionary of inputs (the same as for generating the recordings) and, looping over all the recordings for a given condition (probe, rate, correlation levels), creates a study in ../data/study/ and runs all the sorters on the recordings. This can take a lot of time, depending on the number of recordings/sorters. Note also that, by default, the original recordings generated by MEArec are kept, and thus duplicated in the study folder. If you want to delete the original recordings (they are not needed for further analysis), set keep_data=False.

In [14]:
study_params = generation_params.copy()
study_params['sorter_list'] = ['yass', 'kilosort', 'kilosort2', 'kilosort3', 'spykingcircus', 'tridesclous', 'ironclust', 'herdingspikes', 'hdsort']
study_params['docker_images'] = {'yass' : 'spikeinterface/yass-base:2.0.0'} #If some sorters are installed via docker
study_params['relaunch'] = 'all' #If you want to relaunch the sorters. 
study_params['reset_study'] = False #If you want to reset the study (delete everything)
In [ ]:
all_studies = {}
for rate_level in rate_levels:
    for corr_level in corr_levels:

        study_params['rate_level'] = rate_level
        study_params['corr_level'] = corr_level
        all_studies[corr_level, rate_level] = generate_study(study_params)

And this is it! Now you should have several studies, each with several recordings that have been analyzed by several sorters, in a structured manner (as a function of rate/correlation levels)

probeinterface paper figures

Figures for the probeinterface paper

Here is a notebook to reproduce the figures for the paper

ProbeInterface: a unified framework for probe handling in extracellular electrophysiology

In [2]:
from probeinterface import plotting, io, Probe, ProbeGroup, get_probe
from probeinterface.plotting import plot_probe_group

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
In [3]:
# create contact positions
positions = np.zeros((32, 2))
positions[:, 0] = [0] * 8 + [50] * 8 + [200] * 8 + [250] * 8
positions[:, 1] = list(range(0, 400, 50)) * 4
# create an empty probe object with coordinates in um
probe0 = Probe(ndim=2, si_units='um')
# set contacts
probe0.set_contacts(positions=positions, shapes='circle',shape_params={'radius': 10})
# create probe shape (optional)
polygon = [(-20, 480), (-20, -30), (20, -110), (70, -30), (70, 450),
           (180, 450), (180, -30), (220, -110), (270, -30), (270, 480)]
probe0.set_planar_contour(polygon)
In [4]:
# duplicate the probe and move it horizontally
probe1 = probe0.copy()
# move probe by 600 um in x direction
probe1.move([600, 0])

# Create a probegroup
probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)
In [5]:
fig2, ax2 = plt.subplots(figsize=(10,7))
plot_probe_group(probegroup, ax=ax2)
In [6]:
fig2.savefig("fig2.pdf")
In [7]:
probe0 = get_probe('cambridgeneurotech', 'ASSY-156-P-1')
probe1 = get_probe('neuronexus', 'A1x32-Poly3-10mm-50-177')
probe1.move([1000, -100])

probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)

fig3, ax3 = plt.subplots(figsize=(10,7))
plot_probe_group(probegroup, ax=ax3)
In [8]:
fig3.savefig("fig3.pdf")
In [9]:
manufacturer = 'cambridgeneurotech'
probe_name = 'ASSY-156-P-1'

probe = get_probe(manufacturer, probe_name)
print(probe)
cambridgeneurotech - ASSY-156-P-1 - 64ch - 4shanks
In [10]:
probe.wiring_to_device('ASSY-156>RHD2164')

fig4, ax4 = plt.subplots(figsize=(12,7))
plotting.plot_probe(probe, with_device_index=True, with_contact_id=True, title=False, ax=ax4)
ax4.set_xlim(-100, 400)
ax4.set_ylim(-150, 100)
Out[10]:
(-150.0, 100.0)
In [11]:
fig4.savefig("fig4.pdf")
In [12]:
probe.device_channel_indices
Out[12]:
array([47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31,
       30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14,
       13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0, 63, 62, 61,
       60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48])
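As a sketch of what this wiring means (plain Python, not probeinterface code, with made-up numbers): device_channel_indices[i] gives the device (headstage) channel that contact i is connected to, so contact-ordered data can be recovered from device-ordered data by indexing:

```python
# hypothetical 4-contact wiring: contact i is plugged into device channel wiring[i]
wiring = [2, 0, 3, 1]

# fake device-ordered sample: device channel j carries the value 10 * j
sample_device = [0, 10, 20, 30]

# reorder so that entry i corresponds to contact i
sample_contact = [sample_device[j] for j in wiring]
print(sample_contact)  # -> [20, 0, 30, 10]
```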
In [13]:
probe.to_dataframe(complete=True)
Out[13]:
x y contact_shapes width height shank_ids contact_ids device_channel_indices si_units plane_axis_x_0 plane_axis_x_1 plane_axis_y_0 plane_axis_y_1
0 522.5 137.5 rect 11.0 15.0 2 1 47 um 1.0 0.0 0.0 1.0
1 500.0 50.0 rect 11.0 15.0 2 2 46 um 1.0 0.0 0.0 1.0
2 522.5 187.5 rect 11.0 15.0 2 3 45 um 1.0 0.0 0.0 1.0
3 500.0 125.0 rect 11.0 15.0 2 4 44 um 1.0 0.0 0.0 1.0
4 772.5 112.5 rect 11.0 15.0 3 5 43 um 1.0 0.0 0.0 1.0
... ... ... ... ... ... ... ... ... ... ... ... ... ...
59 772.5 37.5 rect 11.0 15.0 3 60 52 um 1.0 0.0 0.0 1.0
60 750.0 150.0 rect 11.0 15.0 3 61 51 um 1.0 0.0 0.0 1.0
61 750.0 50.0 rect 11.0 15.0 3 62 50 um 1.0 0.0 0.0 1.0
62 750.0 125.0 rect 11.0 15.0 3 63 49 um 1.0 0.0 0.0 1.0
63 772.5 12.5 rect 11.0 15.0 3 64 48 um 1.0 0.0 0.0 1.0

64 rows × 13 columns

Quick benchmark with new API and new sorters (April 2021)

Quick benchmark of the new spikeinterface API with new sorters

In spring 2021, spikeinterface was deeply refactored.

During this refactoring, some sorters were added.

Here is a quick benchmark with one dataset simulated with MEArec.

In [7]:
%matplotlib inline
In [8]:
from pathlib import Path
import os
import shutil
from pprint import pprint
import getpass


import numpy as np
import matplotlib.pyplot as plt

import MEArec as mr
import neo
import quantities as pq


import spikeinterface.extractors  as se
import spikeinterface.widgets  as sw
import spikeinterface.sorters  as ss

from spikeinterface.comparison import GroundTruthStudy
In [9]:
basedir = '/mnt/data/sam/DataSpikeSorting/'

basedir = Path(basedir)

workdir = basedir / 'mearec_bench_2021'

study_folder = workdir /'study_mearec_march_2021'

tmp_folder = workdir / 'tmp'
tmp_folder.mkdir(parents=True, exist_ok=True)

generate recording with mearec

In [ ]:
template_filename = workdir / 'templates_Neuronexus-32_100.h5'
probe = 'Neuronexus-32'
n_cell = 15
duration = 10 * 60.

recording_filename = workdir /  f'recordings_{n_cell}cells_{probe}_{duration:0.0f}s.h5'


fs = 30000.


#~ spgen = mr.SpikeTrainGenerator()
rec_params = mr.get_default_recordings_params()

rec_params['recordings']['fs'] = fs
rec_params['recordings']['sync_rate'] = None
rec_params['recordings']['sync_jitter'] = 5
rec_params['recordings']['noise_level'] = 5
rec_params['recordings']['filter'] = False
rec_params['recordings']['chunk_duration'] = 10.
rec_params['spiketrains']['duration'] = duration
rec_params['spiketrains']['n_exc'] = n_cell
rec_params['spiketrains']['n_inh'] = 0
rec_params['templates']['n_overlap_pairs'] = None
rec_params['templates']['min_dist'] = 0

recgen = mr.gen_recordings(params=rec_params, #spgen=spgen, 
            templates=template_filename, verbose=True,
            n_jobs=1, tmp_mode='memmap',
            tmp_folder=str(tmp_folder))

mr.save_recording_generator(recgen, filename=recording_filename)

set sorter path

In [3]:
user = getpass.getuser()

kilosort_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/KiloSort1'
ss.KilosortSorter.set_kilosort_path(kilosort_path)

kilosort2_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort2'
ss.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

kilosort2_5_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort2.5'
ss.Kilosort2_5Sorter.set_kilosort2_5_path(kilosort2_5_path)

kilosort3_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/Kilosort3'
ss.Kilosort3Sorter.set_kilosort3_path(kilosort3_path)

ironclust_path = f'/home/{user}/Documents/SpikeInterface/code_sorters/ironclust/'
ss.IronClustSorter.set_ironclust_path(ironclust_path)
Setting KILOSORT_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/KiloSort1
Setting KILOSORT2_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort2
Setting KILOSORT2_5_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort2.5
Setting KILOSORT3_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/Kilosort3
Setting IRONCLUST_PATH environment variable for subprocess calls to: /home/samuel.garcia/Documents/SpikeInterface/code_sorters/ironclust

create study

In [6]:
mearec_filename = workdir / 'recordings_15cells_Neuronexus-32_600s.h5'

if study_folder.is_dir():
    shutil.rmtree(study_folder)

rec  = se.MEArecRecordingExtractor(mearec_filename)
sorting_gt = se.MEArecSortingExtractor(mearec_filename)
print(rec)
print(sorting_gt)

gt_dict = {'rec0' : (rec, sorting_gt) }

study = GroundTruthStudy.create(study_folder, gt_dict)
MEArecRecordingExtractor: 32 channels - 1 segments - 30.0kHz
  file_path: /mnt/data/sam/DataSpikeSorting/mearec_bench_2021/recordings_15cells_Neuronexus-32_600s.h5
MEArecSortingExtractor: 15 units - 1 segments - 30.0kHz
  file_path: /mnt/data/sam/DataSpikeSorting/mearec_bench_2021/recordings_15cells_Neuronexus-32_600s.h5
write_binary_recording with n_jobs 1  chunk_size None

plot probe

In [14]:
study = GroundTruthStudy(study_folder)
rec = study.get_recording()
probe = rec.get_probe()
print(probe)
from probeinterface.plotting import plot_probe
plot_probe(probe)
Probe - 32ch
Out[14]:
(<matplotlib.collections.PolyCollection at 0x7f93854cc370>,
 <matplotlib.collections.PolyCollection at 0x7f947882e7c0>)

run sorters

In [ ]:
sorter_list = ['spykingcircus', 'kilosort2', 'kilosort3', 'tridesclous']
study = GroundTruthStudy(study_folder)
study.run_sorters(sorter_list, mode_if_folder_exists='overwrite', verbose=False)
study.copy_sortings()

collect results

In [4]:
study = GroundTruthStudy(study_folder)
study.copy_sortings()


study.run_comparisons(exhaustive_gt=True, delta_time=1.5)


comparisons = study.comparisons
dataframes = study.aggregate_dataframes()
In [10]:
for (rec_name, sorter_name), comp in comparisons.items():
    print()
    print('*'*20)
    print(rec_name, sorter_name)
    print(comp.count_score)
********************
rec0 spykingcircus
              tp    fn    fp num_gt num_tested tested_id
gt_unit_id                                              
#0             0  2772     0   2772          0        -1
#1          2305     0  2127   2305       4432         0
#2             0  3009     0   3009          0        -1
#3             0  2503     0   2503          0        -1
#4          3135     0     4   3135       3139         2
#5             0  2081     0   2081          0        -1
#6          2192     0     2   2192       2194         5
#7          2723     0    55   2723       2778         3
#8             0  3453     0   3453          0        -1
#9             0  2334     0   2334          0        -1
#10         2280    15     8   2295       2288        11
#11         2588     8    12   2596       2600        10
#12         2721   333  1503   3054       4224         8
#13            0  3020     0   3020          0        -1
#14         3612     0  1070   3612       4682         6

********************
rec0 kilosort2
              tp   fn  fp num_gt num_tested tested_id
gt_unit_id                                           
#0          2765    7   6   2772       2771        29
#1          2299    6   0   2305       2299         8
#2          3008    1   0   3009       3008        19
#3          2502    1   2   2503       2504        25
#4          3117   18   0   3135       3117        10
#5          2076    5   1   2081       2077         7
#6          2188    4   0   2192       2188         3
#7          2717    6   0   2723       2717        26
#8          3447    6   0   3453       3447         4
#9          2323   11   5   2334       2328         6
#10         2112  183  54   2295       2166        31
#11         2592    4   0   2596       2592        11
#12         3051    3   0   3054       3051        14
#13         3019    1   0   3020       3019         1
#14         3603    9   0   3612       3603        22

********************
rec0 tridesclous
              tp  fn  fp num_gt num_tested tested_id
gt_unit_id                                          
#0          2727  45  22   2772       2749        14
#1          2294  11   0   2305       2294         4
#2          3003   6   1   3009       3004         1
#3          2467  36  20   2503       2487         9
#4          3123  12   9   3135       3132        13
#5          2047  34   6   2081       2053        10
#6          2159  33  12   2192       2171         7
#7          2695  28   0   2723       2695         6
#8          3420  33   1   3453       3421         5
#9          2293  41  63   2334       2356        12
#10         2230  65  24   2295       2254         3
#11         2532  64  18   2596       2550         2
#12         3023  31  21   3054       3044         0
#13         2979  41  10   3020       2989         8
#14         3588  24  12   3612       3600        11

********************
rec0 kilosort3
              tp    fn  fp num_gt num_tested tested_id
gt_unit_id                                            
#0          2734    38  12   2772       2746         3
#1          2302     3   0   2305       2302        29
#2          3005     4   2   3009       3007        77
#3          2450    53  96   2503       2546        74
#4          2906   229  26   3135       2932         7
#5          2067    14  42   2081       2109         2
#6          1381   811  56   2192       1437        14
#7          2712    11   2   2723       2714        76
#8          3447     6   0   3453       3447         0
#9          2288    46   3   2334       2291         1
#10         1424   871  52   2295       1476        35
#11            0  2596   0   2596          0        -1
#12         3041    13   0   3054       3041        23
#13         1580  1440   0   3020       1580        11
#14         3573    39  97   3612       3670        32
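From these counts, per-unit scores can be derived; for instance the accuracy score (as I understand the comparison module) is tp / (tp + fn + fp). A quick hand check on kilosort2's gt unit #0 from the table above:

```python
# accuracy = tp / (tp + fn + fp), computed from the count_score columns
def accuracy(tp, fn, fp):
    return tp / (tp + fn + fp)

# kilosort2, gt unit #0: tp=2765, fn=7, fp=6
print(round(accuracy(2765, 7, 6), 3))  # -> 0.995
```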

Agreement matrix

In [11]:
for (rec_name, sorter_name), comp in comparisons.items():
    fig, ax = plt.subplots()
    sw.plot_agreement_matrix(comp, ax=ax)
    fig.suptitle(rec_name+'   '+ sorter_name)
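Each entry of this matrix is the agreement between one ground-truth unit and one tested unit; to my understanding it is computed as the number of coincident spikes over the total number of distinct spikes in both trains. A toy version:

```python
# agreement = matches / (n_gt + n_tested - matches)
def agreement_score(num_matches, num_gt, num_tested):
    return num_matches / (num_gt + num_tested - num_matches)

print(agreement_score(100, 100, 100))  # perfect match -> 1.0
print(agreement_score(50, 100, 50))    # half the gt spikes found, no extras -> 0.5
```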

Accuracy vs SNR

In [ ]:
 

Compare old vs new spikeinterface API

Compare "old" vs "new" spikeinterface API

Author : Samuel Garcia 29 March 2021

In spring 2021, the spikeinterface team planned a "big refactoring" of the spikeinterface tool suite.

Main changes are:

  • use neo as much as possible for extractors
  • handle multi segment
  • improve performance (pre and post processing)
  • add a WaveformExtractor class

Here I will benchmark 2 aspects of the "new API":

  • filtering with 10 workers on a multi-core machine
  • waveform extraction with 1 worker vs 10 workers

The benchmark is done on a 10 min spikeglx file with 384 channels.
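For scale, a back-of-the-envelope size for such a file (int16 samples, 384 channels, 30 kHz, 600 s):

```python
# rough raw file size: channels x sampling rate x duration x bytes per sample
n_channels = 384
fs = 30000        # Hz
duration = 600    # s (10 min)
itemsize = 2      # int16

size_bytes = n_channels * fs * duration * itemsize
print(f'{size_bytes / 1e9:.1f} GB')  # -> 13.8 GB
```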

The sorting is done with kilosort3.

My machine is an Intel(R) Xeon(R) Silver 4210 @ 2.20GHz, 2 CPUs with 20 cores each.

In [5]:
from pathlib import Path
import shutil
import time
import matplotlib.pyplot as plt

base_folder = Path('/mnt/data/sam/DataSpikeSorting/eduarda_arthur') 
data_folder = base_folder / 'raw_awake'

Filter with OLD API

Here we:

  1. open the file
  2. lazy filter
  3. cache it
  4. dump to json

The "cache" step is in fact the "compute and save" step.

In [6]:
import spikeextractors as se
import spiketoolkit as st

print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)

# step 1: open
file_path = data_folder / 'raw_awake_01_g0_t0.imec0.ap.bin'
recording = se.SpikeGLXRecordingExtractor(file_path)

# step 2: lazy filter
rec_filtered = st.preprocessing.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)

save_folder = base_folder / 'raw_awake_filtered_old'
if save_folder.is_dir():
    shutil.rmtree(save_folder)
save_folder.mkdir()

save_file = save_folder / 'filetred_recording.dat'
dump_file = save_folder / 'filetred_recording.json'

# step 3: cache
t0 = time.perf_counter()
cached = se.CacheRecordingExtractor(rec_filtered, chunk_mb=50, n_jobs=10, 
    save_path=save_file)
t1 = time.perf_counter()
run_time_filter_old = t1-t0
print('Old spikeextractors cache', run_time_filter_old)

# step : dump
cached.dump_to_json(dump_file)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
<spiketoolkit.preprocessing.bandpass_filter.BandpassFilterRecording object at 0x7f648d3ee130>
Old spikeextractors cache 801.9439885600004

Filter with NEW API

Here we:

  1. open the file
  2. lazy filter
  3. save it

The "save" step is in fact the "compute and save" step.

In [7]:
 
import spikeinterface as si

import spikeinterface.extractors as se
import spikeinterface.toolkit as st
print('spikeinterface version', si.__version__)

# step 1: open
recording = se.SpikeGLXRecordingExtractor(data_folder)
print(recording)

# step 2: lazy filter
rec_filtered = st.bandpass_filter(recording, freq_min=300., freq_max=6000.)
print(rec_filtered)


filter_path = base_folder / 'raw_awake_filtered'
if filter_path.is_dir():
    shutil.rmtree(filter_path)

# step 3 : compute and save with 10 workers
t0 = time.perf_counter()
cached = rec_filtered.save(folder=filter_path,
    format='binary', dtype='int16',
    n_jobs=10,  total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_filter_new = t1 -t0
print('New spikeinterface filter + save binary', run_time_filter_new)
spikeinterface version 0.90.0
SpikeGLXRecordingExtractor: 385 channels - 1 segments - 30.0kHz
BandpassFilterRecording: 385 channels - 1 segments - 30.0kHz
write_binary_recording with n_jobs 10  chunk_size 3246
write_binary_recording: 100%|██████████| 5546/5546 [00:51<00:00, 108.39it/s]
New spikeinterface filter + save binary 54.79437772196252

Extract waveform with OLD API

Here we use get_unit_waveforms from toolkit.

We do the computation with 1 and then 10 jobs.

In [21]:
from spikeextractors.baseextractor import BaseExtractor
import spikeextractors as se
import spiketoolkit as st
print('spikeextractors version', se.__version__)
print('spiketoolkit version', st.__version__)
spikeextractors version 0.9.5
spiketoolkit version 0.7.4
In [24]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_1_job'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=50, n_jobs=1,
            memmap=True)
t1 = time.perf_counter()
run_time_waveform_old_1jobs = t1 - t0
print('OLD API get_unit_waveforms 1 jobs', run_time_waveform_old_1jobs)
OLD API get_unit_waveforms 1 jobs 513.5964983040467
In [30]:
save_folder = base_folder / 'raw_awake_filtered_old'
dump_file = save_folder / 'filetred_recording.json'
recording = BaseExtractor.load_extractor_from_json(dump_file)

sorting_KS3_bis = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
waveform_folder = base_folder / 'waveforms_extractor_old_10_jobs_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
waveform_folder.mkdir()
sorting_KS3_bis.set_tmp_folder(waveform_folder)

t0 = time.perf_counter()
wf, indexes, channels = st.postprocessing.get_unit_waveforms(recording, sorting_KS3_bis,
            max_spikes_per_unit=500, return_idxs=True, chunk_mb=500, n_jobs=10,
            memmap=True, verbose=True)
t1 = time.perf_counter()
run_time_waveform_old_10jobs = t1 - t0
print('OLD API get_unit_waveforms 10 jobs', run_time_waveform_old_10jobs)
Number of chunks: 553 - Number of jobs: 10
Impossible to delete temp file: /mnt/data/sam/DataSpikeSorting/eduarda_arthur/waveforms_extractor_old_10_jobs Error [Errno 16] Device or resource busy: '.nfs0000000004ce04d3000007b8'
OLD API get_unit_waveforms 10 jobs 823.8002076600096

Extract waveform with NEW API

The spikeinterface 0.90 API introduces a more flexible WaveformExtractor object to do the same (extract snippets).

Here are some code examples and speed benchmarks.

In [39]:
import spikeinterface.extractors as se
from spikeinterface import WaveformExtractor, load_extractor
print('spikeinterface version', si.__version__)

filter_path = base_folder / 'raw_awake_filtered'
filered_recording = load_extractor(filter_path)

sorting_KS3 = se.KiloSortSortingExtractor(base_folder / 'output_kilosort3')
print(sorting_KS3)
spikeinterface version 0.90.0
KiloSortSortingExtractor: 184 units - 1 segments - 30.0kHz
In [41]:
# 1 worker
waveform_folder = base_folder / 'waveforms_extractor_1_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=1, total_memory="50M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_1jobs = t1 - t0
print('New WaveformExtractor 1 jobs',run_time_waveform_new_1jobs)
100%|##########| 278/278 [01:42<00:00,  2.72it/s]
New WaveformExtractor 1 jobs 115.03656197001692
In [42]:
# 10 workers
waveform_folder = base_folder / 'waveforms_extractor_10_job_new_'
if waveform_folder.is_dir():
    shutil.rmtree(waveform_folder)
we = WaveformExtractor.create(filered_recording, sorting_KS3, waveform_folder)

t0 = time.perf_counter()
we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we.run(n_jobs=10, total_memory="500M", progress_bar=True)
t1 = time.perf_counter()
run_time_waveform_new_10jobs = t1 - t0
print('New WaveformExtractor 10 jobs', run_time_waveform_new_10jobs)
100%|██████████| 278/278 [00:31<00:00,  8.87it/s]
New WaveformExtractor 10 jobs 48.819815920025576

Conclusion

For the filter with 10 workers the speedup is x14.

For waveform extraction with 1 worker the speedup is x4.

For waveform extraction with 10 workers the speedup is x16.

In [11]:
speedup_filter = run_time_filter_old / run_time_filter_new
print('speedup filter', speedup_filter)
speedup filter 14.635515939778026
In [43]:
speedup_waveform_1jobs = run_time_waveform_old_1jobs / run_time_waveform_new_1jobs
print('speedup waveforms 1 jobs', speedup_waveform_1jobs)

speedup_waveform_10jobs = run_time_waveform_old_10jobs / run_time_waveform_new_10jobs
print('speedup waveforms 10 jobs', speedup_waveform_10jobs)
speedup waveforms 1 jobs 4.464637064152789
speedup waveforms 10 jobs 16.874299751754943
In [ ]: